{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 安装类库\n",
    "# !mkdir /home/aistudio/external-libraries\n",
    "# !pip install imgaug -t /home/aistudio/external-libraries\n",
    "import sys\n",
    "sys.path.append('/home/aistudio/external-libraries')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-03-12 08:33:16,964-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']\n",
      "2021-03-12 08:33:17,303-INFO: generated new fontManager\n",
      "Cache file /home/aistudio/.cache/paddle/dataset/cifar/cifar-100-python.tar.gz not found, downloading https://dataset.bj.bcebos.com/cifar/cifar-100-python.tar.gz \n",
      "Begin to download\n",
      "\n",
      "Download finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shape: (32, 32, 3)\n",
      "label value: cattle\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# 读取数据\n",
    "reader = paddle.batch(\n",
    "    paddle.dataset.cifar.train100(),\n",
    "    batch_size=8) # 数据集读取器\n",
    "data = next(reader()) # 读取数据\n",
    "index = 0 # 批次索引\n",
    "\n",
    "# 读取图像\n",
    "image = np.array([x[0] for x in data]).astype(np.float32) # 读取图像数据，数据类型为float32\n",
    "image = image * 255 # 从[0,1]转换到[0,255]\n",
    "image = image[index].reshape((3, 32, 32)).transpose((1, 2, 0)).astype(np.uint8) # 数据格式从CHW转换为HWC，数据类型转换为uint8\n",
    "print('image shape:', image.shape)\n",
    "\n",
    "# 图像增强\n",
    "# sometimes = lambda aug: iaa.Sometimes(0.5, aug) # 随机进行图像增强\n",
    "# seq = iaa.Sequential([\n",
    "#     sometimes(iaa.CropAndPad(px=(-4, 4))),      # 随机裁剪填充像素\n",
    "#     iaa.Fliplr(0.5)])                           # 随机进行水平翻转\n",
    "# image = seq(image=image)\n",
    "\n",
    "# 读取标签\n",
    "label = np.array([x[1] for x in data]).astype(np.int64) # 读取标签数据，数据类型为int64\n",
    "vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "         'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "         'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "         'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "         'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "         'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "         'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "         'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "         'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "         'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "         'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "         'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "         'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "         'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "         'baby', 'boy', 'girl', 'man', 'woman',\n",
    "         'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "         'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "         'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "         'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "         'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # 标签名称列表\n",
    "vlist.sort() # 字母上升排序\n",
    "print('label value:', vlist[label[index]])\n",
    "\n",
    "# 显示图像\n",
    "image = Image.fromarray(image)   # 转换图像格式\n",
    "image.save('./work/out/img.png') # 保存读取图像\n",
    "plt.figure(figsize=(3, 3))       # 设置显示大小\n",
    "plt.imshow(image)                # 设置显示图像\n",
    "plt.show()                       # 显示图像文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n",
      "valid_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# 训练数据增强\n",
    "def train_augment(images):\n",
    "    # 转换格式\n",
    "    images = images * 255 # 从[0,1]转换到[0,255]\n",
    "    images = images.transpose((0, 2, 3, 1)).astype(np.uint8) # 数据格式从BCHW转换为BHWC，数据类型转换为uint8\n",
    "    \n",
    "    # 增强图像\n",
    "    sometimes = lambda aug: iaa.Sometimes(0.5, aug) # 随机进行图像增强\n",
    "    seq = iaa.Sequential([\n",
    "        sometimes(iaa.CropAndPad(px=(-4, 4))),      # 随机裁剪填充像素\n",
    "        iaa.Fliplr(0.5)])                           # 随机进行水平翻转\n",
    "    images = seq(images=images)\n",
    "    \n",
    "    # 减去均值\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, -1)) # cifar数据集通道平均值\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, 1, -1)) # cifar数据集通道标准差\n",
    "    \n",
    "    images = (images/255.0 - mean) / stdv # 对图像进行归一化\n",
    "    images = images.transpose((0, 3, 1, 2)).astype(np.float32) # 数据格式从BHWC转换为BCHW，数据类型转换为float32\n",
    "    \n",
    "    return images\n",
    "\n",
    "# 验证数据增强\n",
    "def valid_augment(images):\n",
    "    # 转换格式\n",
    "    images = images * 255 # 从[0,1]转换到[0,255]\n",
    "    images = images.transpose((0, 2, 3, 1)).astype(np.uint8) # 数据格式从BCHW转换为BHWC，数据类型转换为uint8\n",
    "    \n",
    "    # 减去均值\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, -1)) # cifar数据集通道平均值\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, 1, -1)) # cifar数据集通道标准差\n",
    "    \n",
    "    images = (images/255.0 - mean) / stdv # 对图像进行归一化\n",
    "    images = images.transpose((0, 3, 1, 2)).astype(np.float32) # 数据格式从BHWC转换为BCHW，数据类型转换为float32\n",
    "    \n",
    "    return images\n",
    "\n",
    "# 读取训练数据\n",
    "train_reader = paddle.batch(\n",
    "    paddle.reader.shuffle(paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "    batch_size=128) # 构造数据读取器\n",
    "train_data = next(train_reader()) # 读取训练数据\n",
    "\n",
    "train_image = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取训练图像\n",
    "train_image = train_augment(train_image)                                                       # 训练图像增强\n",
    "train_label = np.array([x[1] for x in train_data]).reshape((-1, 1)).astype(np.int64)           # 读取训练标签\n",
    "print('train_data: image shape {}, label shape:{}'.format(train_image.shape, train_label.shape))\n",
    "\n",
    "# 读取验证数据\n",
    "valid_reader = paddle.batch(\n",
    "    paddle.dataset.cifar.test100(),\n",
    "    batch_size=128) # 构造数据读取器\n",
    "valid_data = next(valid_reader()) # 读取验证数据\n",
    "\n",
    "valid_image = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取验证图像\n",
    "valid_image = valid_augment(valid_image)                                                       # 验证图像增强\n",
    "valid_label = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64)           # 读取验证标签\n",
    "print('valid_data: image shape {}, label shape:{}'.format(valid_image.shape, valid_label.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型设计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm\n",
    "import math\n",
    "\n",
    "# 模组结构：输入维度，输出维度，滑动步长，基础长度\n",
    "group_arch = [(16, 16, 1, 18), (16, 32, 2, 18), (32, 64, 2, 18)]\n",
    "group_dim  = 64  # 模组输出维度\n",
    "class_dim  = 100 # 类别数量维度\n",
    "\n",
    "# 卷积单元\n",
    "class ConvUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=3, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化卷积单元，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim      - 输入维度\n",
    "            out_dim     - 输出维度\n",
    "            filter_size - 卷积大小\n",
    "            stride      - 滑动步长\n",
    "            act         - 激活函数\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(ConvUnit, self).__init__()\n",
    "        \n",
    "        # 添加卷积\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=filter_size,\n",
    "            stride=stride,\n",
    "            padding=(filter_size-1)//2,                       # 输出特征图大小不变\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False), # 使用MARA 初始权重\n",
    "            bias_attr=False,                                  # 卷积输出没有偏置项\n",
    "            act=None)\n",
    "        \n",
    "        # 添加正则\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0), # 使用常量初始化权重\n",
    "            bias_attr=fluid.initializer.Constant(0.0),  # 使用常量初始化偏置\n",
    "            act=act)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征进行卷积和正则\n",
    "        输入:\n",
    "            x - 输入特征\n",
    "        输出:\n",
    "            x - 输出特征\n",
    "        \"\"\"\n",
    "        # 进行卷积\n",
    "        x = self.conv(x)\n",
    "        \n",
    "        # 进行正则\n",
    "        x = self.norm(x)\n",
    "        \n",
    "        return x\n",
    "\n",
    "# 分割结构\n",
    "class HSBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, splits=5, act=None):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始HS-Block结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim  - 输入维度\n",
    "            out_dim - 输出维度\n",
    "            stride  - 滑动步长，1保持不变，2下采样\n",
    "            splits  - 分割次数\n",
    "            act     - 激活函数\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(HSBlock, self).__init__()\n",
    "        \n",
    "        # 计算通道\n",
    "        channel0 = out_dim // splits\n",
    "        channel1 = channel0 * 2\n",
    "        channel2 = channel0 * splits\n",
    "        \n",
    "        # 特征平分\n",
    "        self.conv1 = ConvUnit(in_dim=in_dim, out_dim=channel2, filter_size=1, stride=1, act=act)\n",
    "        \n",
    "        # 特征升维\n",
    "        self.conv2 = ConvUnit(in_dim=channel0, out_dim=channel1, filter_size=3, stride=1, act=act)\n",
    "        \n",
    "        # 重复合并\n",
    "        self.conv3 = ConvUnit(in_dim=channel1, out_dim=channel1, filter_size=3, stride=1, act=act)\n",
    "        self.conv4 = ConvUnit(in_dim=channel1, out_dim=channel1, filter_size=3, stride=1, act=act)\n",
    "        self.conv5 = ConvUnit(in_dim=channel1, out_dim=channel1, filter_size=3, stride=1, act=act)\n",
    "        \n",
    "        # 合并特征\n",
    "        self.conv6 = ConvUnit(in_dim=channel2 + channel0, out_dim=out_dim, filter_size=1, stride=1, act=act)\n",
    "            \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x - 输入特征\n",
    "        输出:\n",
    "            x - 输出特征\n",
    "        \"\"\"\n",
    "        # 特征平分\n",
    "        x = self.conv1(x)\n",
    "        x0, x1, x2, x3, x4 = fluid.layers.split(input=x, num_or_sections=5, dim=1)\n",
    "        \n",
    "        # 特征升维\n",
    "        x1 = self.conv2(x1)\n",
    "        x1_0, x1_1 = fluid.layers.split(input=x1, num_or_sections=2, dim=1)\n",
    "        \n",
    "        # 重复合并\n",
    "        x2 = fluid.layers.concat(input=[x2, x1_1], axis=1)\n",
    "        x2 = self.conv3(x2)\n",
    "        x2_0, x2_1 = fluid.layers.split(input=x2, num_or_sections=2, dim=1)\n",
    "        \n",
    "        x3 = fluid.layers.concat(input=[x3, x2_1], axis=1)\n",
    "        x3 = self.conv4(x3)\n",
    "        x3_0, x3_1 = fluid.layers.split(input=x3, num_or_sections=2, dim=1)\n",
    "        \n",
    "        x4 = fluid.layers.concat(input=[x4, x3_1], axis=1)\n",
    "        x4 = self.conv5(x4)\n",
    "        \n",
    "        # 合并特征\n",
    "        x = fluid.layers.concat(input=[x0, x1_0, x2_0, x3_0, x4], axis=1)\n",
    "        x = self.conv6(x)\n",
    "        \n",
    "        return x\n",
    "\n",
    "# 基础结构\n",
    "class ResBasic(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, is_pass=True):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化基础结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim  - 输入维度\n",
    "            out_dim - 输出维度\n",
    "            stride  - 滑动步长\n",
    "            is_pass - 是否直连\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(ResBasic, self).__init__()\n",
    "        \n",
    "        # 是否直连标识\n",
    "        self.is_pass = is_pass\n",
    "        \n",
    "        # 添加投影路径\n",
    "        self.proj = ConvUnit(in_dim=in_dim, out_dim=out_dim, filter_size=1, stride=stride, act=None)\n",
    "        \n",
    "        # 添加卷积路径\n",
    "        self.con1 = ConvUnit(in_dim=in_dim, out_dim=out_dim, filter_size=3, stride=stride, act='relu')\n",
    "        self.con2 = HSBlock(in_dim=out_dim, out_dim=out_dim, stride=1, splits=5, act='relu')\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x - 输入特征\n",
    "        输出:\n",
    "            x - 输出特征\n",
    "            y - 输出特征\n",
    "        \"\"\"\n",
    "        # 直连路径\n",
    "        if self.is_pass: # 是否直连\n",
    "            x_pass = x\n",
    "        else:            # 否则投影\n",
    "            x_pass = self.proj(x)\n",
    "        \n",
    "        # 卷积路径\n",
    "        x_con1 = self.con1(x)\n",
    "        x_con2 = self.con2(x_con1)\n",
    "        \n",
    "        # 输出特征\n",
    "        x = fluid.layers.elementwise_add(x=x_pass, y=x_con1, act='relu') # 直连路径与卷积路径进行特征相加\n",
    "        \n",
    "        return x\n",
    "    \n",
    "# 模块结构\n",
    "class ResBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, basics=1):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化模块结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "            in_dim  - 输入维度\n",
    "            out_dim - 输出维度\n",
    "            stride  - 滑动步长\n",
    "            basics  - 基础长度\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(ResBlock, self).__init__()\n",
    "        \n",
    "        # 添加模块列表\n",
    "        self.block_list = [] # 模块列表\n",
    "        for i in range(basics):\n",
    "            block_item = self.add_sublayer( # 构造模块项目\n",
    "                'block_' + str(i),\n",
    "                ResBasic(\n",
    "                    in_dim=(in_dim if i==0 else out_dim), # 每组模块项目除第一块外，输入维度=输出维度\n",
    "                    out_dim=out_dim,\n",
    "                    stride=(stride if i==0 else 1), # 每组模块项目除第一块外，stride=1\n",
    "                    is_pass=(False if i==0 else True))) # 每组模块项目除第一块外，is_pass=True\n",
    "            self.block_list.append(block_item) # 添加模块项目\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x      - 输入特征\n",
    "        输出:\n",
    "            x      - 输出特征\n",
    "        \"\"\"\n",
    "        for block_item in self.block_list:\n",
    "            x = block_item(x) # 提取模块特征\n",
    "            \n",
    "        return x\n",
    "\n",
    "# 模组结构\n",
    "class ResGroup(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化模组结构，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(ResGroup, self).__init__()\n",
    "        \n",
    "        # 添加模组列表\n",
    "        self.group_list = [] # 模组列表\n",
    "        for i, block_arch in enumerate(group_arch):\n",
    "            group_item = self.add_sublayer( # 构造模组项目\n",
    "                'group_' + str(i),\n",
    "                ResBlock(\n",
    "                    in_dim=block_arch[0],\n",
    "                    out_dim=block_arch[1],\n",
    "                    stride=block_arch[2],\n",
    "                    basics=block_arch[3]))\n",
    "            self.group_list.append(group_item) # 添加模组项目\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入的特征图像提取特征\n",
    "        输入:\n",
    "            x      - 输入特征\n",
    "        输出:\n",
    "            x      - 输出特征\n",
    "        \"\"\"\n",
    "        for group_item in self.group_list:\n",
    "            x = group_item(x) # 提取模组特征\n",
    "            \n",
    "        return x\n",
    "        \n",
    "# 残差网络\n",
    "class ResNet(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            初始化残差网络，H/W=(H/W+2*P-F)/S+1\n",
    "        输入:\n",
    "        输出:\n",
    "        \"\"\"\n",
    "        super(ResNet, self).__init__()\n",
    "        \n",
    "        # 添加初始化层\n",
    "        self.conv = ConvUnit(in_dim=3, out_dim=16, filter_size=3, stride=1, act='relu')\n",
    "        \n",
    "        # 添加模组结构\n",
    "        self.backbone = ResGroup() # 输出：N*C*H*W\n",
    "        \n",
    "        # 添加全连接层\n",
    "        self.pool = Pool2D(global_pooling=True, pool_type='avg') # 输出：N*C*1*1\n",
    "        \n",
    "        stdv = 1.0/(math.sqrt(group_dim)*1.0)                    # 设置均匀分布权重方差\n",
    "        self.fc = Linear(                                        # 输出：=N*10\n",
    "            input_dim=group_dim,\n",
    "            output_dim=class_dim,\n",
    "            param_attr=fluid.initializer.Uniform(-stdv, stdv),   # 使用均匀分布初始权重\n",
    "            bias_attr=fluid.initializer.Constant(0.0),           # 使用常量数值初始偏置\n",
    "            act='softmax')\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        功能:\n",
    "            对输入图像进行分类\n",
    "        输入:\n",
    "            x - 输入图像\n",
    "        输出:\n",
    "            x - 预测结果\n",
    "        \"\"\"\n",
    "        # 提取特征\n",
    "        x = self.conv(x)\n",
    "        x = self.backbone(x)\n",
    "        \n",
    "        # 进行预测\n",
    "        x = self.pool(x)\n",
    "        x = fluid.layers.reshape(x, [x.shape[0], -1])\n",
    "        x = self.fc(x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tatol param: 1621160\n",
      "infer shape: [1, 100]\n"
     ]
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.base import to_variable\n",
    "import numpy as np\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # 输入数据\n",
    "    x = np.random.randn(1, 3, 32, 32).astype(np.float32)\n",
    "    x = to_variable(x)\n",
    "    \n",
    "    # 进行预测\n",
    "    backbone = ResNet() # 设置网络\n",
    "    \n",
    "    infer = backbone(x) # 进行预测\n",
    "    \n",
    "    # 显示结果\n",
    "    parameters = 0\n",
    "    for p in backbone.parameters():\n",
    "        parameters += np.prod(p.shape) # 统计参数\n",
    "    \n",
    "    print('tatol param:', parameters)\n",
    "    print('infer shape:', infer.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAD8CAYAAACINTRsAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJztnXeYVdXVh999p/fOMDADQ4ehw4AIiICKgIUYK0ETJYmJxhqNLYnRqBE1seWzobFFBZXYUbGhoCJIB+kMZQYYpsH0fvf3xz63TR+Yy9wZ1vs8PPfcfdo69wy/s87aa6+ttNYIgiAInQNbexsgCIIgtB0i6oIgCJ0IEXVBEIROhIi6IAhCJ0JEXRAEoRMhoi4IgtCJEFEXBEHoRIioC4IgdCJE1AVBEDoR/t44aHx8vE5NTfXGoQVBEDola9asydNaJxzvcbwi6qmpqaxevdobhxYEQeiUKKX2tcVxJPwiCILQiWiRp66U2gsUA7VAjdY63ZtGCYIgCMdGa8IvU7TWeV6zRBAEQThuvBJTb4jq6mqysrKoqKg4UafstAQHB5OcnExAQEB7myIIgo/RUlHXwGdKKQ08p7We39oTZWVlERERQWpqKkqp1u4uWGityc/PJysri169erW3OYIg+Bgt7SidqLUeBcwA/qCUmlR3A6XU1Uqp1Uqp1bm5ufUOUFFRQVxcnAj6caKUIi4uTt54BEFokBaJutb6gPWZA7wLjG1gm/la63StdXpCQsOpliLobYP8joIgNEazoq6UClNKRTiWgWnAZm8YU15VS2lljTcOLQiCcFLQEk89EfhWKbUBWAUs1lp/6g1jduYUszu3pM2Pm5+fz4gRIxgxYgRdu3ale/fuzu9VVVUtOsZVV13F9u3bW3zOF154gZtuuulYTRYEQTgmmu0o1VpnAMNPgC1eIy4ujvXr1wNwzz33EB4ezq233uqxjdYarTU2W8PPuZdeesnrdgqCIBwvPjmiVGt9Qs6za9cu0tLSmDNnDoMHD+bQoUNcffXVpKenM3jwYP7+9787t504cSLr16+npqaG6Oho7rjjDoYPH86pp55KTk5Ok+fZs2cPU6ZMYdiwYZx11llkZWUBsHDhQoYMGcLw4cOZMmUKAJs2bWLMmDGMGDGCYcOGkZGR4b0fQBCETscJy1N3594Pf2LLwaJ67Y54emiQP63tCkzrFsnfzhvcalu2bdvGq6++Snq6GSQ7b948YmNjqampYcqUKVx00UWkpaV57FNYWMjpp5/OvHnz+OMf/8iLL77IHXfc0eg5rr32Wn7zm98wZ84c5s+fz0033cSiRYu49957+frrr0lMTOTo0aMAPP3009x6661ceumlVFZWnrAHnCAInQOf9NQ5gULWp08fp6ADLFiwgFGjRjFq1Ci2bt3Kli1b6u0TEhLCjBkzABg9ejR79+5t8hwrV67ksssuA+CXv/wly5cvB2DChAn88pe/5IUXXsButwMwfvx47r//fh5++GEyMzMJDg5ui8sUBOEkoV089cY86o1ZxlvtnxhBcIDfCbElLCzMubxz506eeOIJVq1aRXR0NJdffnmD+eCBgYHOZT8/P2pqji1j5/nnn2flypV89NFHjBo1inXr1nHFFVdw6qmnsnjxYqZPn86LL77IpEn1hgUIgiA0iE966rX29gk5FBUVERERQWRkJIcOHWLJkiVtctxx48bx1ltvAfDaa685RTojI4Nx48Zx3333ERMTw4EDB8jIyKBv377ceOONnHvuuWzcuLFNbBAE4eSgXTz15rC3Uxx51KhRpKWlMXDgQHr27MmECRPa5LhPPfUUc+fO5cEHHyQxMdGZSXPzzTezZ88etNZMmzaNIUOGcP/997NgwQICAgLo1q0b99xzT5vYIAjCyYHyRkdcenq6rjtJxtatWxk0aFCT+znCLz3jwogKkWJVTdGS31MQhI6DUmpNW5Q196nw
S0JEEHDiUhoFQRA6Gz4l6rFhpgNSNF0QBOHY8ClRV1Z2umi6IAjCseFbom6NOJLwiyAIwrHhW6JufYqkC4IgHBu+JeqWqy6OuiAIwrHhY6JuPr0RfpkyZUq9wUSPP/4411xzTZP7hYeHA3Dw4EEuuuiiBreZPHkydVM4m2oXBEHwFr4l6tanNxz12bNns3DhQo+2hQsXMnv27Bbt361bNxYtWuQFywRBENoO3xJ1pVAor4RfLrroIhYvXuycFGPv3r0cPHiQ0047jZKSEs444wxGjRrF0KFDef/99+vtv3fvXoYMGQJAeXk5l112GYMGDeKCCy6gvLy82fMvWLCAoUOHMmTIEG6//XYAamtrufLKKxkyZAhDhw7lscceA+DJJ58kLS2NYcOGOQuBCYIgtIT2KRPwyR2QvamBFZp+VVXYbH7g30rTug6FGfMaXR0bG8vYsWP55JNPmDVrFgsXLuSSSy5BKUVwcDDvvvsukZGR5OXlMW7cOM4///xG5wJ95plnCA0NZevWrWzcuJFRo0Y1adrBgwe5/fbbWbNmDTExMUybNo333nuPlJQUDhw4wObNZnZAR/ndefPmsWfPHoKCgpxtgiAILcGnPHW0nWCqUNi9cnj3EIx76EVrzV133cWwYcM488wzOXDgAIcPH270OMuWLePyyy8HYNiwYQwbNqzJ8/74449MnjyZhIQE/P39mTNnDsuWLaN3795kZGRw/fXX8+mnnxIZGek85pw5c3jttdfwb+3DTRCEk5r2UYzGPOqaCsjZSnFgErHxXdv8tLNmzeLmm29m7dq1lJWVMXr0aABef/11cnNzWbNmDQEBAaSmpjZYcretiYmJYcOGDSxZsoRnn32Wt956ixdffJHFixezbNkyPvzwQx544AE2bdok4i4IQovwLU9dmRrqSnvHUw8PD2fKlCnMnTvXo4O0sLCQLl26EBAQwNKlS9m3b1+Tx5k0aRJvvPEGAJs3b262PO7YsWP55ptvyMvLo7a2lgULFnD66aeTl5eH3W7nwgsv5P7772ft2rXY7XYyMzOZMmUKDz30EIWFhZSUtP1k3IIgdE58y/1zinqt104xe/ZsLrjgAo9MmDlz5nDeeecxdOhQ0tPTGThwYJPHuOaaa7jqqqsYNGgQgwYNcnr8jZGUlMS8efOYMmUKWmvOOeccZs2axYYNG7jqqqucsx49+OCD1NbWcvnll1NYWIjWmhtuuIHo6Ojjv3BBEE4KfKr0LlqjD62nyC+WqMSebW5XZ0JK7wpC56JTlt5FKezYvNZRKgiC0NnxLVEH7NiweSmmLgiC0Nk5oaLeklCPHZvXOko7C1LFUhCExjhhoh4cHEx+fn6zgmRXNmx4r6O0o6O1Jj8/n+Dg4PY2RRAEH+SEZb8kJyeTlZVFbm5uk9tVFR5GYSfgqHijjREcHExycnJ7myEIgg9ywkQ9ICCAXr16Nbvdyodvp3vlbpL/+tMJsEoQBKFz4XMdpYW1QfjXlErcWBAE4RjwOVHfX+pHOOVkFjRf+VAQBEHwxOdEPTE+gXBVQZlVIlcQBEFoOT4n6mm9ugFQVlzYzpYIgiB0PHxO1P1DogCoLJU64oIgCK3F50Q9IMwh6uKpC4IgtBafE/WgUCPqVWVF7WyJIAhCx8P3RD3ciHqtiLogCEKrabGoK6X8lFLrlFIfedOgkPAYAOzlR7x5GkEQhE5Jazz1G4Gt3jLEgV9EIgA7M/Z4+1SCIAidjhaJulIqGTgHeMG75gChsdRqha08lw2ZkgEjCILQGlrqqT8O3AYnYPYKmx/5RJFAIeXVUq1REAShNTQr6kqpc4EcrfWaZra7Wim1Wim1urlKjM2Rp6OIV4UE+vtcP64gCIJP0xLVnACcr5TaCywEpiqlXqu7kdZ6vtY6XWudnpCQcFxG5elIElQh6riOIgiCcPLRrKhrre/UWidrrVOBy4CvtNaXe9OoXKKJV4VU10qlRkEQhNbgk/GN0oBYEiikukZi6oIg
CK2hVaKutf5aa32ut4xxMDV9KEGqGnuFlAoQBEFoDT7pqeswE5OvKjzczpYIgiB0LHxS1AnrAsBzi1e0syGCIAgdC98U9XAj6l3UUTi0EWpr2tkgQRCEjoFPinp1qCkVMNq2A56bBNu8Wm5GEASh0+CToh4b14UKHcAptm2AhhKJrQuCILQEnxT16LAgsnUsA9R+01AhZXgFQRBagk+KOsBhYvBT1uCjSkltFARBaAk+K+rZOtb1RTx1QRCEFuHDoh7j+lJZ3H6GCIIgdCB8VtRzPERdPHVBEISW4LOi3rdvP9cXCb8IgiC0CJ8V9dFDBru+iKcuCILQIvzb24DGqI4byHp7HxSaYRVFUltdEAShBfisp24LieJnVfex2j5APHVBEIQW4rOi7m8zvnkxIaiqErBLbXVBEITm8FlR93OIug4xDZLWKAiC0Cw+K+oBfsa0YkJNg4RgBEEQmsVnRd1BsXaIunjqgiAIzeHzol6CFX6RXHVBEIRm8XlRd3nqIuqCIAjN4bOiHhceCJjsF0A8dUEQhBbgs6IeGujPqrvO4JCOo1r7oQ9vaW+TBEEQfB6fFXWA0CB/yghmg+6D3rusvc0RBEHweXxa1EMC/ABYYU9DHVwnGTCCIAjN4NOi7hiAtMKehtK1sG9FO1skCILg2/i0qDvYYu9pFgp2t68hgiAIPk6HEPVCwtDKBmX57W2KIAiCT9MhRF1jwx4cI6IuCILQDB1C1AHsIXFQmtfeZgiCIPg0Pi/qd80cCEBtcCyUFbSzNYIgCL6Nz4t61ygzorQmOFbCL4IgCM3g86LumCxjyZ5q8nMPtbM1giAIvo3Pi7ojV/1QdShRugjs9na2SBAEwXfxeVEP8DOiXqAj8Vd2qCxsZ4sEQRB8F58XdZtyiHqEaZDOUkEQhEbxeVEvLK8G4AgOUZfOUkEQhMZoVtSVUsFKqVVKqQ1KqZ+UUveeCMMcxIcHAW6euuSqC4IgNEpLPPVKYKrWejgwApiulBrnXbNcTOgbzyXpyW7hF/HUBUEQGqNZUdeGEutrgPVPe9WqOlw9qQ8FVvilVjx1QRCERmlRTF0p5aeUWg/kAJ9rrVd61yxP+nYJ5+YZw6nQAdhLxVMXBEFojBaJuta6Vms9AkgGxiqlhtTdRil1tVJqtVJqdW5ublvbSaC/HwVEYC9p+2MLgiB0FlqV/aK1PgosBaY3sG6+1jpda52ekJDQVvY5CQrw44iOQIunLgiC0CgtyX5JUEpFW8shwFnANm8bVpcgfxv5OhJVLnnqgiAIjeHfgm2SgFeUUn6Yh8BbWuuPvGtWfQL9bRwhAlV+8ESfWhAEocPQrKhrrTcCI0+ALU0S5O9Hlo5AyYhSQRCERvH5EaUOAv1tHNERBFQXUV5e0d7mCIIg+CQdRtSD/G3OXPVJ9/6PvXml7WyRIAiC79GxRN0aVRqjitmdW9LMHoIgCCcfHUbUHR2lALGqGJtVZ10QBEFw0WFEPcjfz+mpp6gcRNIFQRDq04FE3UaejgLgkYD5JO474VmVgiAIPk/HEfUAG/lEcW3VDZToYKLy1ra3SYIgCD5HxxF1fz8APraPY5fuTkjh7na2SBAEwffoQKLuMnW37kZIkYi6IAhCXTqMqAf6uUzdZe9OcPlhqCxuR4sEQRB8jw4j6u4pjLt1kll4dRYcXN9OFgmCIPgeHUbU3dmtu5mFA2tgzUvta4wgCIIP0SFFfa/uSnbcKeZL8eH2NUYQBMGH6FCinhoXCkAtfnw66jkYdD7k72pnqwRBEHyHDiXqXSKCncsllTUQ1xeO7IHa6na0ShAEwXfoUKKeEBnkXC6urIH4fmCvgaP729EqQRAE36FDifqDPx/Ko5cMB+C5bzKMpw4SghEEQbDoUKIeGRzAz0cluxocop63s30MEgRB8DE6lKg7cHSYfpNZA4HhUJjVzhYJgiD4Bh1S1M8e3BWAX730I0QkQbFMRi0IggAdVNSLK2tcXyK7QZGIuiAIAnRQ
Ub9iXE8A4sMDLVE/1M4WCYIg+AYdUtQHJUVy9uBE8kqqqA1PgpJssNe2t1mCIAjtTocUdYAlP5nyAFtKwkyuemluO1skCILQ/nRYUR+ebKa2I7K7+ZS4uiAIQscV9b/PGgJAni0OgJUbNrWnOYIgCD5BhxX14AAzvd2fluQB8NF3azlcVNGeJgmCILQ7HVbUHdPb5RNBlfYjSRWgdSsOUJoH5Ue8Y5wgCEI70XFFPcCYrrGRQwyJqoAau73lB3j7Slh8q3eMEwRBaCf829uAYyXI38+5nK1jSaKAWnsrXPXiQ6Bb8RAQBEHoAHRcT93fZXq2jiVRHSFo52I4srdlB6gqhcoi7xgnCILQTnQSUY8hWeWRuOR36BVPtewAVaVQWewl6wRBENqHDivq/n42Hvz5UAAO6ViCVDVK21mzYUPzO2stoi4IQqekw4o6QHrPGACydZyzLay8BYOQaipB14qoC4LQ6ejQoh4ebPp5s3WMsy1Z5ZHbXL56Van5rK0yAi8IgtBJ6NCiHhbkEPVYZ1uEKmf19oymd6wqcS2Lty4IQiei2ZRGpVQK8CqQCGhgvtb6CW8b1hLCAo35OcRQqxU5xJCkCvi/d5ayqyiAqNAALhvTg0D/Os8uh6cOUFEIYfEn0GpBEATv0ZI89RrgFq31WqVUBLBGKfW51nqLl21rFj+b4opxPfnvD/v4ffXNaBQvBP6LZJXLvz7fAUB+SRU3T4gH/2AINNPgeYh6U556VakJz4TGNr6NIAiCD9Fs+EVrfUhrvdZaLga2At29bVhLue9nprDX5/Z0frQPACBZucrwHi6qgJfPgc/vdu1U3UJRf+50eLhXm9orCILgTVo1olQplQqMBFZ6w5jjpZAwCnUofZQrA6ayqhJyt0NgmGvDlnrq+TvNp9agVBtbKwiC0Pa0uKNUKRUO/A+4SWtdbyimUupqpdRqpdTq3Nz2mrBCsc7ej3TbdmdLcNlhk76YtxNnxa+WirqDkpy2M1FrKSQmCILXaJGoK6UCMIL+utb6nYa20VrP11qna63TExIS2tLGZrlt+gDn8ir7APrbDhCNEeuISmv+0oqjUFZglj2yX1pQKuDovrYyFfYsg0f6QuGBtjumIAiCRbOirpRSwH+ArVrrR71vUusZ38eVvbLaiqun20xHaVSF22AkRziltZ760f3HbaOTI3vN9HtFIuqCILQ9LYmpTwCuADYppdZbbXdprT/2nlmtY0RKNG/85hRG9Ihm5N1VVGp/Jto2McfvC9KLd7s2zN8FPcZBVZn5rvyaFnVlM5Uc29JTd7wlVBS23TEFQRAsmhV1rfW3gM/3Eo7va7z1SgL53D6aOX5fEqBqTWZ9WBcTx85zeOol4B+CDghGNSbqdrurNG9beuqVIuqCIHiPDj2itDHeqD3DCLqD0hyI7Q35u8gpruCTtbup9gshq9SfrMOHGz5IlZvYHzlGT91eW7/NcdymRL2iCBbOkbi7IAitplOK+gp7Gt/UDmNe9WUA5IekQlxfdP4uXvl+L2WlheRV+VNCCKW5mTQ4D567B3/4J6itbp0RGxbCw7094/fQMk89eyNs+wj2r2jdOQVBOOnplKKusfGr6jt4tvZ8zqu8n3OO3EJNbB9q83bzzNKdhFFJYW0Q39iHM6B8HXzzEFprthx0y4SpsJZHXG48/U1vt86Ita+ajJujmZ7tjph6U1k3jiwdx6cgCEIL6XSiPigp0uP7Jt2bbOJ4d18w/rqa9wL/ynS/HykjiIdqLuV7vzGw6nkWrd7PzCeX89U2KxzjEN0hF0DiUFj2CFSXt8yI4mzY971Zrpvl4ngDKD8C+xrxxB157OUi6oIgtI5OJ+pv/W4cn9x4Gn+eOYj48CBn+5t7ggEYZtsDQJrah8bGwvKxUJbH0V1mkGxGrhUucYhvUBScfT8UZMCX97XMiK0fYnpogaI69d0d4ZfN78JL0yF7U/39HWJelt+y8wmCIFh0OlGPCA5gUFIkv53Um9V/
OZM0y3Pfo5M8tltt7w/AMvswtLLRt9B41n42K9HHEfMOjoTek2HMb+GHp2Hvt80b8dO7ENvHLBcf8lzn6CittI7f0Jyq7uGX2hr4eh6UtkDgd33RtqNfW4JdJu8WBF+i04l6Y+QTSaEOJcPelZEVz/L76psBOEoEq2r7M+TIlyjs+DtE3empR5jPs+6F2F7wzu/qe9+b/wcvnGUEzhF6GXoxhCVA5kp46Rz4vzHmmJUlnvsWHTIdsY/0hbxdsGe5S9TLC+DgOvj6QdNx2hRVZfD6xfDDM8fxK7WSbx+Hv8dCdTOTkgiCcMI4aUQdFM/VnMfjNRdyhEhKCHWueaNmKgmV+5lk24TNKepWTD3IitEHhsHFLxsPfuEvPA+9+yvIWgVH98L2jwENg38Gkd2M97zvW8jbAbk7PEsUgIm5//QulObC1/+AV86FrR+YdWUFJuwD9T1+B2tfNQ+ZwkxroFQb5tQ3xw/PANrY0BCb34F/Dmh5X4QgCMfNSSTq8HTtLD6wT6jX/rF9HId1NNf7v8vd71oTV1cWmxGl7tUdk4azZdD1xnvOdxupWmDi9GRvhsNbzIMgYSBEWhWKHQ+Gwv31PfXiQ0b4AbZYYu54oJQXQIF1nrpvB462D66HNa+4cukLs1rwS7SSDQtNCMjB1o/gy79DglVz5/sn3YqllZmqmAA7P4OS7PoZQIIgeI2TStQboxp/Hqq+jHTbDl4JmGfCKRVFJvRSp+Tu71ZZxcp2fu5qdIj64Z+MZx3b2+zneCCMmGM+j+zzrOUOpqP04DqzbK+TC19W4Hp4FB8yD4QXZ0Dmj6YtZ6v5LMx0lTJorKZMVZln/Ptopiv+/sENnvXm7bUmlu9g3WumP8Eh3G/OgeX/cu1fmGnsrCyBRwfBU2ONd35gTeM2lebDV/d7nkcQhONGRN3iHfsknqz5mcmO+exu46kHRdbbLlMnstueZLxQMOJVbHnRhzcbzzq2t/nuyF7pd6Y5Vu428z04ynXAHGsCqeQx9Y2qKnF5vUWHIPMH2P89bH3ftDmO5yHqB40or3oeDq43D5yjmfDUKcarB6ipgv9MM3H+jG9MmGTHEtd53/09LJzt+l6QYcJOZfmeo2sLdkPPiWZ5//fGY684ar5nbzIhJ2hY1DcvMmmih9bXXycIwjEjou7GozWX8ETNz6Eoi13rl2EPifNYry1PdZl9mOkMra1xZa/4BRkRPbof4qzMl2n3w/DZkDoJolJcAh6ZbD5DrGnyQuNd3ryybklwtPk8bKU8Fh+EzFVm+eB6I/KO4x3NdMXSda0R049vhe8ehydHwuNDTOhnwxvGo972oTmetsOiq0xGTkGGeRhoDbs+h91LzQOrqswlyvm7TCjGQW0V9BwPoXEm537PMpf9m//n2q6hcgeHNlrrmggXVRTCG5c1309QfrThUcGCcBIiol6H9XYjyH1VFkVdx3msq641wrHB3gdqyrnwvpd56HWrWGX/s6Eoywilw1NPHAwXPAv+gRCd4gqXRFmx9pSx5rPvGdBnKsT3h7RZpi2ur+vEIbHGS96zzHzP+hGeGGbCImBEt2AP+AWa745Mme2f4MyXH3ox2AJg1XxY/RLEpMKEG11vE7VVRmDzd5vBT/ZqyFrtmXK5+BbTmRviNmdreBfocarVWbza9XByiHpguPld6pJt9V24i3ruDlNawfE7HVgLOz6BLe/X399BWQE8NhhWPtf4NoJwEiGiXoefdCo12vwsm0M9QyKVNaZA12adCkDPqp1U51kx79Nvd23oyFF3JyrFCCdA37Og3zQYMMP6fibE9ITrfoQ+Z5g2d1FPO9987l9hBLWmwnUsMMvZG6F7uvnu6HCtsVINr18LP38ekoabztyD663zz/S0MWerCaM42Pe9q6MWTHip50T4vVuufngiDDzXdIjaq801hcSabJ4ug811ODp5y4/AxrdM+CfHCh3lbDGdsJUlJoOoLN/1QHDs53hDaYhdX5gw1ZqXTJz+mYkm
m8idjW+7HhSC0MkRUa9DBUHs0CmU6iDmfuVPYXk127OLqaypparGdDRm6G6U60CG2PZymm0TRPWArkOMJwwuT92d6BTXcuJgmPM2pP0Mxl8PA89xrYs3g6JIsR4o6XNh0Hmu9adeaz77nmnEc+TlrnX9zjSf2RtxVksOS3B13Mb2Np2yVcXmwdNlEET3dJ1zwaWuuHviENi73JVS6WD0lSZV088arRveBYZd6lqfMs4VfkqdCFHJJvyiNcyfDO/81sTTHZ3C6183efj/+7XrXGtegeWPur5nrmo8vOLoC8jdBq+eb8JVa152ra8shnevhu+eaHh/QehktGri6c5OckwIWUfKebbmPGJVEVUEMPxe0yGaEBHE7yYZsbZjY4vuyWTbenqqwzDsFnOAP6w0WTHh9afz093TXUXpg8LNZ0i0ibu7kzIW5iwy4ZiB50JEV5dXG9YFJt5iBHXIhSaMc/gnVxhm9FWw60vY9x2knGIGPqWc4srgie3tyr6J62Pa57xtwjZPjrDOkWDeJMITYMVTJl4eGmfi2/Ya44krZYT9yB6zvc1m3gayVkNYnDlP1o/Q6zQzmGrPMuN9O0I5PzxtPmN6mWMA7PjU1SlcmgNf3gshMeZ7SbbZNzTWHKf/DIhMMh7/ri/M77R/hatevvuo2sxVJiTm6LRtjG8fM30bo64w33d/Zfbre2bT+wmCj3HSiPpLV41hY2YhiZFB3PFOA/VWgPOHd+Ppr3fzgX18vXW5xZXcv9j1Cv+jfQC/9zex6yfyRjE58yhJUd2JGXM185fuond8GDOGukoTvHygGxU153GN/4eu/PWGUAr6nWWWI7qaz4QBcN4TMOh8I6ATbnBtH2W9AcT1M6L3izfh87/B0ItM+KSHW7+A+xuEYznBNb8rADdthoBg2L/SeLdb3jNhmsl3mtGyjgdSVLIR5PAu1vn7uDz0LmnmraXnBBPrrywyo09jUk2qaPYm84aQOtEcI2k4HNpgLY8wIZ2dS0y4JjjKeNtPn2p+m+oyCLzbtHcZaLJt0ufCxa+YTtqlDxiBrq6wruMHY5Nj4nHVwHwvR/aauj5d0oyoa23SPANC4Dovi3pVmXlQjby8YdsEoZV0elEf0DWCLYeK6BUXxpQBRoCVhyOYAAAgAElEQVQcor7j/hmc/3/fsi27mOun9qXW3vIMisdrLuSQjsOfWv6zDh5b9x0A6T1jWL3PVFncO88VVrn3wy3AbMb++l+MDotv6JCNo5QJezREcCRc8R50G2m+B0XAudZUsj3rPJwcQm7zN6Lqzm+XmvMEmMJnJKcbL7w0F067FZKGmX8OIruDf4jpCK3L2Kuh/3TzkEmbZfLRD2+C0+8wKYw7PjUx/TDrjab/DBNTL9ht3ixmPgz/Hm2ybVLGwWm3mDi5tkPv081yQYbx0vueaTqaHSQNMxlAuVvNb+KoSV9ZZB5KkUkmX3/9axA/AHqcAiuetvbZZurmH/7JpIkqP6iphG8eMoPKLnsdbH6e1/rtY8ae335tHriNcXQ/bHjTXIv7dhveMB3QXdIgeXTj+wtCC+n0ov6PC4Zy8ehkUuPD6q0L9Lc5C3iNSIkmNNCfp7/eXW+7hqggiFdqz67X7hD0xlCODJW2pM+Ulm0X28t8RvcAvzq3vvsoz+82PxPvL8xqWGzGXm3eAhryLgNDjRcNpgN4/HXGUx92iTmvQ9RLc61zjzbncc/xTxhoRD2ymxHeHqe4jj/wHCO2q+bD4J97nrvrUPM5fzJc9JIJvyQONQ+VvB1G1D+83oSsonvAr7+Adf91PcAObTRxfrCEfjus/a8JCX37GEy61fN8G9+GnJ/gG6uzd/o/6v8eAOvfMH0HabMgob+rPWu1+czbIaIutAmdvqM0JNDPOX+pgy9vOZ3Pbp4E4CzgFeTvx6l94jitXyu96FaigLd+zOT0R5ZSUuk5mvL7XXmk3rGYnYebmAz7eAiNNfnvDWXnNMSEG2HmIw2vSx4N6Ve17DhT
/mwye+L6mLz9U64xHnf/s81yr9NcYSJ3UYfGQ1X+QeahE1VnfXQq9Bhv3iAWzYXaSjjzHrPujUtgx2fGY04carzn1y8yIZ0ZD5ltXpgKq/9jvHgwk6OU5pj4/vdPes6AVZxtBB2MN//DU67RxQ4yvoEFs11jChzbO3Bk9jQX8xeEFtLpRb0h+iSE0z/RVF8M8LNZn0bcpw/p2mbnyTpSBrgGLQHU2jW3/W8j+/LLGPK3JR7bf7jRpPCt2tv85Bhaa/67Yi+F5a2cZu+Mu2HcNa3b53ix+bni7VHJMGOeEeWweLMcEGI82El/gl7mYUuXQeYzslsrz2WDuZ/AJa8C2sT1+55hOoNrKuCd35jMm0m3QK/TTabQgHNMf4WDM++F335p+gVW/J9pm/oX01nsXnp5+yfmMybV1bboKpNW6RD/ZY+YIm87rfo+h91Evcyttk/+Tld73i5T9VPmqBWOgZNS1N1xiLojnq5ourOqe3RIi4/9m1fMq/WPe10hmcqa+vXHy6tq+eynbKcNfo10mO3NK2XBKjO6cl3mUf76/k/c+c7GFtsDwJhfe8agfYXgSCOcjph+yinGS+9+jCGJvmeYzKKzHzAhot8tMymkjjr53UfDL9+HW7bDxS+BX4DpawA49TrTNxHR1cTxk0aYQVXKD/77M3hyFMzrAR/dZAT90tdg1tOms/rgOhPq2fm5Gci1d7k5piPr6PAWk1nz5uVWRU9MamreLpftX/3d5Ozv/tII/9IHzahZQWgBnT6m3hz+lodeWdv0ZA9nD05kyU+HGZMaw6TABKe4NkVBqRkgdMlzrmnrqhoQ9Xs++Ik3V2cysKt5e7A1IuoXPfs9eSVVXJKeQkV1rcc5Oh3RKfDHLcd3jPHXu5a7DDIZJlveM6mLUSlG7CPc3syuX2OE3dHfMOlPcHCt6SgOCDFppJvegsQ0iEgyoaIhF5oMoK5DTR/AqufNCOLvHjeTjtsCwD/YjA1QfrB9sRHrmgrj6Ud2NymZq54zdffHX+8aQZu12nTC7v7K9FNMuNH0Pbx5Ocx4xDWWoSHWvW6yi2J6Nr6N0Ck56T31O2cMYmj3KMakxja53ageJmd6YFL9Il+N0VA2jWNUqjt78o0XV2SFUpw13QG7XZNZYMI4DgGvqK51jv4XWkGPU43Idh/VcAdvTKoJDzkY/SuTSuoYOHb+v+GO/cYzn/mICWM5UjrBpH3euB5GXmHGCBQfgtkLXP0FqVbZ59g+ZkCZvcbsk5hm2vd9aypgBkdBt1Gmc3X3V6Z/wCH0H99m3gaW/wtK88xgsQ9u8CwFnb8b3r/WhH7AZPcc2nDcP5/QMTjpPfW0bpF8eP1E5/cg/4afcz8flUxKbChnD+7KXY3kudelVmuq67wBLFrjGSetqbU7BdrxDLj9fxuZNjiRyOAAnl22m4c/3c4956U511dU1zo1vblwUVN8te0ww5OjiXOby7VTExQO5/zLNYK2tQQEu8JDDeEfCP7xMOUuM8o2rq9pO7jepF/OeAQylsKoX5kO3AEzzXaOEsuleSbF8dTrTbjm4FqTejrql/DVfWb8wfbF5rg7PoXnTjcZOzZ/yPjahJMOrHEJ+M7PTKG1JXfC4AvMJC9Cp+ekF/W6zBrRjf0FZTzx5U6P9tBAP2Zag4mqGgjVzJ3Qixe/88x8OFpWzT8+9qw58sXWwx7fy6pdnrvd6lCttWte+W4v15/Rj+925QFwz4euUER5dX1vH6C61s6jn+/gmsl9iAwOaPI6y6pqmPvyaoZ2j/J4qHV6Rv/K++fwD3J53wDjfm9GCncZ6Er1JBRGWDNohSfAtPuMuEf3MKUgHBOnpM81IZ7vnzQhnT5nmDeGl2dCRDcj1DYbvHyea1QwmIFYJYfhtQvN94xvTH6+zWYGf1UUmkFP8f3NQyIqxc22BqitNjWGAuunBgu+hYh6Hfz9bNx8Vn9mDk0ip7iCK/5jUs5CAlyDThoKoQzo2sAgHOCl7/Z6fB+UFMnW
Q0XO7xVVtWjL73b36v/1+Q7+9XnDaW4V1fUfKtmFFXy57TDPfL2b8qpa7jl/cCNXiHUuc869eaVNbie0AUERJoWzOUJjYcqdZrnf2SZTacxvzBvG9WtNf8CQi0x5iRvrhFOu/R5Wv2jq9n89z4SGvn/SePw9J5jSEetfM7H2TGuULQrna2JQpEk3zdliHiw2P5O7f9ot5gHw+iXm7eJ3y0xFzKKDMP3B+oOxhHZHRL0RBnSNYIDVcQmecW6HqIYE+FFeXcu0tESCA5r/4x7YNcJD0AEWrc1y1qoqrWzYA69LRR1PvabWzrgHv3R+t7egtniN9QCRkek+SkCwEVQHYfFG4BsjJhXO+rtZHn2VyasfdJ75DAyHRwea+HtYAsx42GzfZZCpzaPt8P2/Ye0rZmTr5v+ZttoqU1gNbTp5dS08e5or/TJjKQSEmvEKQy404SPHADd3amtMeeiIJBOOEryKiHoz3P+zIbyz1rMeeGKkiUEnRQeTkVvKmWmJznWn9Ipl5Z76eeazx6awIbOwXvvDn253LjcU1mmIiupa5wAlpaC0ylPkA/2a7/92eOruDyuhkxBmTe7iqNcPZgCYzd+Ec0KiXe0jrfr3Iy83k6T4+ZuaOyiTlpm71QxY6z4Ktn1s4vSn327CMRlLTZ2cD280c9aWHzV1fPJ3mc7ekGiz79H9Zmau4GgTcspcZTKHuo82ZZPDEyFvu8kSCop0zQwW38+acavAjHMoyTFvCCljTQmIogMmHTTQNYl8m+FwjDqg1yOi3gyXj+vJ5eM808L+em4a4/vEs/VQEU9/vdtDRLtEBrPzgRk8sHgrL3+/19l+ev8urNvfNrnGK3bnO0MzSsGDdeL2S7Zks2htFv+ePZLT+tWvGAmuUI+fUmTklhAfEdRsHN7bVFTXsi27mBEp0c1vLLSO029rer1SrlROR3XMYRd7bpM6sU4ZhL+ZB8E7V5tsn75nmjz8oRebUbrlR80DIibVpGqufcVU6EwaYR4Im95yHcov0Bqw1cK0rqAoqCwElHnjOJoJ3UaYjCNlMwO3epxi3lJytprU1cBw8/AoyTUhsegU05cQ3cO8mRRnu+Yg2P+DmfkrPMFURY3rY66lrMC0OYrDVRSZ1NXIbuZB5ai71I6IqB8DoYH+nDe8G2elJRIfHsR5w7vxyeZDgBmZGuBnIznGc5BSbFggY1Jj2ZZtPOzT+yfwzY7cYzr/sp2u/b7blQ/ke6zPLCgH4Ir/rPIoKuZOlTP8opj6r28Y2j2KD66bgFKKp7/eRWRwQL2Hmbe5+/3NvLU6i29vn0JyjBe8L6HtsfnBRf9xdcI2xcjLTTmEpOFm+5Jsa2asg8YrVzZTkdNR5jl/l+lnCIoyIZ+Irqb89J5vzMjexMHmwZH1o/HcM1eacs7+QSaNdOcXJozVJc0IdvkRs0//FHOOo/vhwGrTV4E1ZiEg1Ah899EmXbWi0Ah99kbzJhHdE8oLrHCU3Qyai+ruOb1kOyOifhwEB/gxd6KJIZ49uCtzJ/TiuqlmxqKz0hI9SvX2ig/jL+cOYtGaLMqrazlnWNIxi7r7CNVjxTEIyhF/33SgkF53fsw/Lx7uDAmVVNZw5fhUZ3/By9/tIc56iHmD9ZnmTaa4oqaZLQWfozlBBzOAK2m4a3tHCYhgt7EfwZGu7+7x+Xi3mcBG/MKVOdQW2GvNA6UDhloa4qQffNRWBPjZuPu8NGLDTEdQfJ3c7/jwQIL8/bh2sqmBklBnfb8uDWfPHC/lVbW8vTqTzQcK6XPXx6TesZhZ//etM/xSUkdAn1rqGq4+75NtDL/3M1ZbtWju+XAL1y9Y5xU7wZWn39iIWkHwCja/TiPoIJ661wgL8ufZy0eTWVBGSKAfyvqjuXZKXwZ3j2TygARe/80pzHlhJdDyTtLWMujuT+u1bcgqdIp63fPuqZPiWFlj56JnVzQYxqmqsRNoDdbKLqwgI7ekXkXM1uAofKbRaK2dv5kgCC1HPHUvMn1I
V347qbdHbNrPppg6MBGlFBP6xvPEZWbASGsm6GgLLnxmRfMbuZFTVOFcfmrpLr7fnUf/v3zCR1ZlyRlPLOMX1gNq1Z4CvrQGWX2zI5fNB+pn/TSEI+Fg+uPLeaMFtXUEQaiPeOrtTGqcGaE3tHsUWUfKm9y2b5dwduWUAJ4zLJ0Ixv7DlQf/yBJXGuYbK/dTa9ccKTN1axas2s+dVhmFvfPO4VcvrnIuN4d7fv2f393MnFOkGJUgtJZmPXWl1ItKqRyl1OYTYdDJxvCUaN77wwRuPqvpeiRXjk/lqV+Mont0CPf9bAiLrqk/j2p78P3ufG5cuN75/U63ujivuKV0OjhwtJzSyoY7Qk/wy4ogdEpaEn55GZjuZTtOakakRDun1esVH8ame6bV2+ae8wczoGsE390xlStOcKrhsfK3D36q1zZh3ldc/p+VHm2Pfradv7y3qd5I2Jo68f6fDhbyqZU6ejxorXlnbVaD5R4EoaPTbPhFa71MKZXqfVNObhxdglprItwGAb1z7fh6mTIOlt82hZziCmd8/Len9eL55aao2JDukWw+UNTgfu1BcUW1Nfk29QZhPfnVroZ2obiihsLyaib/82vOGZrE4k1G0B2hnEueW0FpZQ2LbzitVbYs3Z7DH9/awPbsYu6cOai1lyIIPo3E1H2EbtEhBPrZ+NPZnpXyHHXcGyIlNpSU2FDumzWY8GB/LhiZTL/ECG5btJEFvx3H0Hs+87bZLeaudzfz4YaDzu/bsosI9vfj6v+ubnSfQ4UV/PeHfQBOQQczqchdMwexyirHUF1rd85g1RIcUwBmu3X+CkJnQekWFH+yPPWPtNZDmtjmauBqgB49eozet29fG5l4crJu/xFKK2uZeBwTYX+3K8+ZMunOV7ecztR/fXM85rUJp/aOY0VGfqPro0MDOFrW8BysT84eyQ1uOfO3TuvPDxkFPHvFaMKDjK+itaayxs4LyzO4ZEwKXSJMLfT31h3gpjfXc/7wbjw5e2QbXpEgHDtKqTVa6/TjPU6bpTRqredrrdO11ukJCQ3XGxFazsgeMccl6AAT3HLGbzyjn3O5d0I4H7uFLJ7/ZTpTB3ZxDoxycMW4nkSHukJBw5KjjsueujQl6ECjgg6eKZYA//xsB9/uyuOrbTkAzro8A//6Kf/8bAe3vm3mct2TV+qscvnBhoMsO8ZRvYLgq0ie+kmCu6iDmfHpHGvSj6kDu/DilWO4bfpA/nqua3KHW88ewNq/nMUtZ/Vnw93TGNyteVG/bkrfZrdpC9xLMLizYOV+Jj28lBlPLPdIvVy2I5eVGflM+efX3OGWoeOY2GTp9pxG53t9e3UmWUfK2tB6QfAezcbUlVILgMlAvFIqC/ib1vo/3jZMaBt+d3pvnvsmA5tN8fScUR7lCB6/bAR/Oz/NmXkDMMnt7SDI34bNprje+UBoPFTXNTKY7KKKdq+w2JT3f+n8H+q1ZRdWcMmzK1i1t4DkmBD+ck4a+aWVzhz5qho7f1pkvPwnLhvBrBHdvWO4ILQRzXrqWuvZWuskrXWA1jpZBL1jceeMQc5skZlDk+iX6Jr4I8DP5owzOwh0m6O17nytju6Xi0Yne7Q/cMEQvv7TZBZePY4z0xI5Y6BrMublt01ptc2O+jkngm3ZxayyattkHSnn96+t4c/vbqbcqlHvPnWgez5+Q3yzI5ddOcXeM1YQWoCEXwQPHJ2MQL3aK9dO7svw5Cj+PHOQR6x9VI8YggP8GNfbTM7w7BWjWX7bFHY9MIOUWFcJ3WV/muKxnzvL/uQS/y//2PDUb5ekJzfY7g3+9sFmDh4tZ/i9nhlE3+zI5fvdedjtpj5N6h2LSb1jMS8sz+BXL67izEeXnTAbBaEhRNQFD+LCg3jpyjEsuWlSvXU94kJ5/7qJxIQFMmWA8cZ/f3ofBrpN+wfmDSAlNhR/tzTDiGB/esSF8uF1ZpLrqJAAj0FWPeJCGZsayz3npREa5Dk14G8m9uLVuWM9jtcY
19Tp7D1W1u0/ymMNzBH7qxdX8YvnV3LB0995DKKqG+MvsUbNVtfaefjTbRxpJF4vCG2N5KkL9ZjiFj5pjHkXDuUPU/rQt0tEs9tuuHsafn7G60+JDW20Dsxbvz8VcFVrdPAXq/P2A7c8d4CrJ/Vm/rIMj7bbpw/kq605bD/ccBhkyU2TOPvx5r3pnTkldIsOaXT9hqzGi5Q5UiZ/Pqo7SVHBPP31bkora7h3VqMZwYLQZoinLhwTQf5+LRJ0gKjQAI+wTnO4h32emTPKuVxdp2zAxDplfu+bNRiAV389lt9N6s1Fo5O5ND3FY5sBXSO4+9w04sODuGum50CvuhzrJCY3vWli7++sPcBTS3cD8MqKfSxak9Xg9lrretd2rOQWV9YrnyycXIinLrQrn950WpOjQWdYaZfgmq1p9tgU/G02TusXz6tzxzKudxx78koZYIWBEiODncP/q2rsvLk6E4BVfz4DgLkTezF3Yi/25Zfyj4+3eeW6GuLWtzegteZi60Gz5WARSVHBzF+ewTNfG/Fv7C3Gbtcs3Z7D1IFdmqwzf+qDX1Jj1y2qiil0TkTUhXZlYNfI5jeycHizUwZ0YdrgrgBM6m8Gug3o2vBbQ6C/jXevHU9MaGC9TJ+ecWHcPn0gD3164oT9T4s2smhNFiutEgd1eX/9AY6UVnHlhF68t+4AK3bn8+DPh7JobRa3LdrIQxcO5dIxPQB47Yd9JEYGE+hvY+Gq/Tw9ZxQ1jZS63Jdfyn9X7OOumYOw2WTykc6MiLrgk9wwtS8JEZ6FzGYMSeKLrTkMSmr5gwDM6NzGuGZyH6eo33xmfx77wnSOvvGbU5yTfrjzxGUjCA3057evNl6zpjkaE3RwpU2WVtU6B09tPFDo/C12HC5h+uPLnBOYu1NZUz+E8+nmQ4zsEcMNC9axIauQi9KTG3yQ2u0aDR5jFoSOiYi64JP8cdqAem0Xjk7m/BHdWlW8qzVcP7WvU9Tdp+ULDrCRGBnMvvwyhnSPok9COKvuOoP9BWWs3FPgFN+RPaJZt/8ofzsvzVmREuDfs0cyqX8CK3bn89W2w7y1uuHYujvuo2G3Hipiq1XP7D/f7ml0n2tfX+tcduTZ//61tQzsGuGczLum1uXJ55VUsmbfEc4e3JVfv/IjS7fnStimEyCiLnQovCHoL181hvAgf2dYYrg1KnbJTZPYll3ErBHdeezzHTzx5U66RpoQTpfIYLpEBvPjXjP71O9O781PVqnj1LgwesWHMXtsCt2jQzlnmOkXmD6kK6f0iuXbnXkcLPSsXfP+HyYw66nvjus6HHVvwHNuWnev/qttOezMKWbW8O6k3/8FAPOvGM3S7fU7hY+WVTH7+ZU8cdkI+ie2rFO8Lt/vzuMXz6/k05tOa1WoTTh2RNSFk57JAzxHwMZYI1oHdI1wxupvOrMf10zuQ3CAZw69I1pRW6u55/zB3PfRFk7tE8fSWyc3eK6YsEA+vXkSo/7+uTP+HeCniAk15/S3qUbj4u788az+PNpAHn1zOPbJLa50tl393zXOZbtdc8vbG9hfUMYaa7rEaY8t4+pJvTlUWMGV43vib7NRUV1Lr/gwwoL8CbMym7YcLKJHXCgFJVW8umIvd84cxOKN5hVjZUZBPVF3zMu79VARlTW1jO4Z2+rraQ2llTVoaFUmVmvILCijpLKm1eHBtqZFpXdbS3p6ul69+thjjoLQUfhxbwEXP7uC+VeMdnbetgS7XfPW6ky6RYfQKz6MiGB/Rvz9c4Z2j+KS9GQC/W3c/j9TeGxC3zi+2+VZ02btX89i84FCM0jqi9aLe1ty2ZgU/jClL6c9vNSj/fObJ3HWY2ZMQFpSJJMHJJAaF8bF6clkF1Vw6oNfkZYUyZZD5g3HEfrZfKCQfonhrNl3hNd/2M8Tl41gT14pPeJCCfSz8c/PtnPe8G4M7BrpLK8cHODHX97bRLC/n3NcQ136//kTaux2Mh5sPsRUU2vHz6aazDSqy53vbOLz
Ldms/stZLd7HnbYqvSueuiAcB2NSY1l/91lEh7auXo3NprhsbA+PtscvHcH4PnF0sUI8IYH+9IoLo8Zu57td33PN5D7EhAYwpFsUsWGBTOqfwKT+CVw3tS/vrTvALW9vaPR8d80cyNfbc/l+d9Pljo+FhT9msvDHzHrtv7QmHQfYcqjIKd5Hy6ucqaSONoA739lIZEgAz32TgZ9NOT35cb1j+ev7ZmrEZy8fxVNLd/PeuoN8d8dU7v1wC2+s2s+bV4/jtR/2W9c6iM+2ZJNfWsXPRnSnoLSKL7cepsptLIDWmg82HGR4cjSHCiv4aONBHrhgKGD6GtLv/4I7ZwzkV+NTPd7OHv1sO326hNcr7FZVY+dwUYXX3gJag3jqgtBJSL1jMQBvXj2OTzZns3pfAecO68ZbqzNZfP1p5JdW8vX2XP7ynplD/vFLR7ByTwELVu2vd6xAP5uHCPoil6anOMcguKOUq/hcQzwzZxSbDxY6B4Y5eOCCIUzql0BmQZlH5tOdMwby4Cfb+OzmSUyz3jz+ccFQHvp0G387L43U+DDmPL+Scisk1VjorTnaylMXUReETsKevFJKK2sY0r3puvcfbzpEdEgA4/vGU1pZw6YDhXyw4SB/mNKX0AA/th8uZmDXCNZnHuXKl35s8BgPXDCE/yzfQ0ad0at/OnuAR+bOycixZhCJqAuC4HUqqmvxsyn++dl2vt+Vz9yJqUSFBDB1YCJgJh8J8rcxpHuUs8MUYOm2HPp2CWfF7nzSukWy8Mf9XDm+F/kllVw6/wdG94zhrpkDufCZFfx8ZHfeWXfAue8/LhiKUnDwaDn/bmBS8lun9eeLrTmszzxab11b4WdTnDM0qV69oZYgoi4IwklFWVUN/jYbgf42SiprCA/y5+DRcg4eLScpOoTudQqpaa0pqazh3H9/y8Wjk7luaj+01mzIKmThqv1cOiaFhIggZw7/ucO6sfNwMfM+3Ua/LuEEB/ixfGceAHsenMmGrEJe/HYPad0iOa1fPIO7RVFr1zy3bDf/W5PF7txSrpvSl1vPHsD+/DJ+8cIPZB0pd9rz8Q2nMfPJ5QC8c+14lu3IpbrWzlNLdzMgMYIlN9evcNoSRNQFQRC8wLr9RxjaPcqj1HN1rZ1HP9/Bryf2Ij48iPfXHyAkwM8j4ym3uJLgABsRwQ3PGdAcIuqCIAidiLYSdSm9KwiC0IkQURcEQehEiKgLgiB0IkTUBUEQOhEi6oIgCJ0IEXVBEIROhIi6IAhCJ0JEXRAEoRPhlcFHSqlcYN8x7h4P5LWhOe1JZ7oWkOvxdeR6fJvmrqen1jrheE/iFVE/HpRSq9tiVJUv0JmuBeR6fB25Ht/mRF2PhF8EQRA6ESLqgiAInQhfFPX57W1AG9KZrgXkenwduR7f5oRcj8/F1AVBEIRjxxc9dUEQBOEY8RlRV0pNV0ptV0rtUkrd0d72uKOUSlFKLVVKbVFK/aSUutFqj1VKfa6U2ml9xljtSin1pHUtG5VSo9yO9Str+51KqV+5tY9WSm2y9nlSKaW8fE1+Sql1SqmPrO+9lFIrrfO/qZQKtNqDrO+7rPWpbse402rfrpQ62639hN5LpVS0UmqRUmqbUmqrUurUDn5vbrb+zjYrpRYopYI72v1RSr2olMpRSm12a/P6PWnsHF64lkesv7eNSql3lVLRbuta9bsfy71tEq11u/8D/IDdQG8gENgApLW3XW72JQGjrOUIYAeQBjwM3GG13wE8ZC3PBD4BFDAOWGm1xwIZ1meMtRxjrVtlbausfWd4+Zr+CLwBfGR9fwu4zFp+FrjGWr4WeNZavgx401pOs+5TENDLun9+7XEvgVeA31jLgUB0R703QHdgDxDidl+u7Gj3B5gEjAI2u7V5/Z40dg4vXMs0wN9afsjtWlr9u7f23jZrrzf/s7XiRzsVWOL2/U7gzva2qwl73wfOArYDSVZbErDdWn4OmO22/XZr/WzgObf256y2JGCbW7vHdl6w
Pxn4EpgKfGT9x8hz+yN13g9gCXCqtexvbafq3iPHdif6XgJRGBFUddo76r3pDmRihMzfuj9nd8T7AxzneqEAAALzSURBVKTiKYRevyeNnaOtr6XOuguA1xv6PZv73Y/l/15ztvpK+MXxh+wgy2rzOaxXoJHASiBRa33IWpUNJFrLjV1PU+1ZDbR7i8eB2wC79T0OOKq1rmng/E6brfWF1vatvUZv0QvIBV5SJpz0glIqjA56b7TWB4B/AvuBQ5jfew0d9/64cyLuSWPn8CZzMW8L0PprOZb/e03iK6LeIVBKhQP/A27SWhe5r9PmcerzqURKqXOBHK31mva2pY3wx7waP6O1HgmUYl67nXSUewNgxYBnYR5W3YAwYHq7GuUFTsQ9ORHnUEr9GagBXvfmeVqDr4j6ASDF7Xuy1eYzKKUCMIL+utb6Hav5sFIqyVqfBORY7Y1dT1PtyQ20e4MJwPlKqb3AQkwI5gkgWinl38D5nTZb66OAfFp/jd4iC8jSWq+0vi/CiHxHvDcAZwJ7tNa5Wutq4B3MPeuo98edE3FPGjtHm6OUuhI4F5hjPUBoxuaG2vNp/b1tGm/E0o4hXuWP6QTphasTYXB72+VmnwJeBR6v0/4Inp0yD1vL5+DZ8bPKao/FxH9jrH97gFhrXd2On5kn4Lom4+oofRvPzpprreU/4NlZ85a1PBjPDqEMTGfQCb+XwHJggLV8j3VfOuS9AU4BfgJCrfO9AlzfEe8P9WPqXr8njZ3DC9cyHdgCJNTZrtW/e2vvbbO2evM/Wyt/tJmYrJLdwJ/b2546tk3EvMZtBNZb/2Zi4ltfAjuBL9z+4BTwlHUtm4B0t2PNBXZZ/65ya08HNlv7/B8t6BBpg+uajEvUe1v/UXZZf2RBVnuw9X2Xtb632/5/tuzdjltGyIm+l8AIYLV1f96zBKDD3hvgXmCbdc7/WgLRoe4PsADTJ1CNeZv69Ym4J42dwwvXsgsT73bowbPH+rsfy71t6p+MKBUEQehE+EpMXRAEQWgDRNQFQRA6ESLqgiAInQgRdUEQhE6EiLogCEInQkRdEAShEyGiLgiC0IkQURcEQehE/D/6z+Bvz2jKqQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "complete - train time: 33847s, best epoch: 282, best loss: 1.207718, best accuracy: 70.12%\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.utils.plot import Ploter\n",
    "import numpy as np\n",
    "import time\n",
    "import math\n",
    "import os\n",
    "\n",
    "epoch_num = 300   # 训练周期，取值一般为[1,300]\n",
    "train_batch = 128 # 训练批次，取值一般为[1,256]\n",
    "valid_batch = 128 # 验证批次，取值一般为[1,256]\n",
    "displays = 100    # 显示迭代\n",
    "\n",
    "start_lr = 0.00001                         # 开始学习率，取值一般为[1e-8,5e-1]\n",
    "based_lr = 0.1                             # 基础学习率，取值一般为[1e-8,5e-1]\n",
    "epoch_iters = math.ceil(50000/train_batch) # 每轮迭代数\n",
    "warmup_iter = 10 * epoch_iters             # 预热迭代数，取值一般为[1,10]\n",
    "\n",
    "momentum = 0.9     # 优化器动量\n",
    "l2_decay = 0.00005 # 正则化系数，取值一般为[1e-5,5e-4]\n",
    "epsilon = 0.05     # 标签平滑率，取值一般为[1e-2,1e-1]\n",
    "\n",
    "checkpoint = False                   # 断点标识\n",
    "model_path = './work/out/hs-resnet'  # 模型路径\n",
    "result_txt = './work/out/result.txt' # 结果文件\n",
    "class_num  = 100                     # 类别数量\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # 准备数据\n",
    "    train_reader = paddle.batch(\n",
    "        reader=paddle.reader.shuffle(reader=paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "        batch_size=train_batch)\n",
    "    \n",
    "    valid_reader = paddle.batch(\n",
    "        reader=paddle.dataset.cifar.test100(),\n",
    "        batch_size=valid_batch)\n",
    "    \n",
    "    # 声明模型\n",
    "    model = ResNet()\n",
    "    \n",
    "    # 优化算法\n",
    "    consine_lr = fluid.layers.cosine_decay(based_lr, epoch_iters, epoch_num) # 余弦衰减策略\n",
    "    decayed_lr = fluid.layers.linear_lr_warmup(consine_lr, warmup_iter, start_lr, based_lr) # 线性预热策略\n",
    "    \n",
    "    optimizer = fluid.optimizer.Momentum(\n",
    "        learning_rate=decayed_lr,                           # 衰减学习策略\n",
    "        momentum=momentum,                                  # 优化动量系数\n",
    "        regularization=fluid.regularizer.L2Decay(l2_decay), # 正则衰减系数\n",
    "        parameter_list=model.parameters())\n",
    "    \n",
    "    # 加载断点\n",
    "    if checkpoint: # 是否加载断点文件\n",
    "        model_dict, optimizer_dict = fluid.load_dygraph(model_path) # 加载断点参数\n",
    "        model.set_dict(model_dict)                                  # 设置权重参数\n",
    "        optimizer.set_dict(optimizer_dict)                          # 设置优化参数\n",
    "    else:          # 否则删除结果文件\n",
    "        if os.path.exists(result_txt): # 如果存在结果文件\n",
    "            os.remove(result_txt)      # 那么删除结果文件\n",
    "    \n",
    "    # 初始训练\n",
    "    avg_train_loss = 0 # 平均训练损失\n",
    "    avg_valid_loss = 0 # 平均验证损失\n",
    "    avg_valid_accu = 0 # 平均验证精度\n",
    "    \n",
    "    iterator = 1                                # 迭代次数\n",
    "    train_prompt = \"Train loss\"                 # 训练标签\n",
    "    valid_prompt = \"Valid loss\"                 # 验证标签\n",
    "    ploter = Ploter(train_prompt, valid_prompt) # 训练图像\n",
    "    \n",
    "    best_epoch = 0           # 最好周期\n",
    "    best_accu = 0            # 最好精度\n",
    "    best_loss = 100.0        # 最好损失\n",
    "    train_time = time.time() # 训练时间\n",
    "    \n",
    "    # 开始训练\n",
    "    for epoch_id in range(epoch_num):\n",
    "        # 训练模型\n",
    "        model.train() # 设置训练\n",
    "        for batch_id, train_data in enumerate(train_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = train_augment(image_data)                                                        # 使用数据增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "\n",
    "            label_data = np.array([x[1] for x in train_data]).astype(np.int64)                        # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                             # 转换数据类型\n",
    "            label = fluid.layers.label_smooth(label=fluid.one_hot(label, class_num), epsilon=epsilon) # 使用标签平滑\n",
    "            label.stop_gradient = True                                                                # 停止梯度传播\n",
    "\n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label, soft_label=True)\n",
    "            train_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            # 反向传播\n",
    "            train_loss.backward()\n",
    "            optimizer.minimize(train_loss)\n",
    "            model.clear_gradients()\n",
    "            \n",
    "            # 显示结果\n",
    "            if iterator % displays == 0:\n",
    "                # 显示图像\n",
    "                avg_train_loss = train_loss.numpy()[0]                # 设置训练损失\n",
    "                ploter.append(train_prompt, iterator, avg_train_loss) # 添加训练图像\n",
    "                ploter.plot()                                         # 显示训练图像\n",
    "                \n",
    "                # 打印结果\n",
    "                print(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\".format(\n",
    "                    iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "                \n",
    "                # 写入文件\n",
    "                with open(result_txt, 'a') as file:\n",
    "                    file.write(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\\n\".format(\n",
    "                        iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "            \n",
    "            # 增加迭代\n",
    "            iterator += 1\n",
    "            \n",
    "        # 验证模型\n",
    "        valid_loss_list = [] # 验证损失列表\n",
    "        valid_accu_list = [] # 验证精度列表\n",
    "        \n",
    "        model.eval()   # 设置验证\n",
    "        for batch_id, valid_data in enumerate(valid_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = valid_augment(image_data)                                                        # 使用图像增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "            \n",
    "            label_data = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64) # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                       # 转换数据类型\n",
    "            label.stop_gradient = True                                                          # 停止梯度传播\n",
    "            \n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算精度\n",
    "            valid_accu = fluid.layers.accuracy(infer,label)\n",
    "            \n",
    "            valid_accu_list.append(valid_accu.numpy())\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label)\n",
    "            valid_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            valid_loss_list.append(valid_loss.numpy())\n",
    "        \n",
    "        # 设置结果\n",
    "        avg_valid_accu = np.mean(valid_accu_list)             # 设置验证精度\n",
    "        \n",
    "        avg_valid_loss = np.mean(valid_loss_list)             # 设置验证损失\n",
    "        ploter.append(valid_prompt, iterator, avg_valid_loss) # 添加训练图像\n",
    "        \n",
    "        # 保存模型\n",
    "        fluid.save_dygraph(model.state_dict(), model_path)     # 保存权重参数\n",
    "        fluid.save_dygraph(optimizer.state_dict(), model_path) # 保存优化参数\n",
    "        \n",
    "        if avg_valid_loss < best_loss:\n",
    "            fluid.save_dygraph(model.state_dict(), model_path + '-best') # 保存权重\n",
    "            \n",
    "            best_epoch = epoch_id + 1                                    # 更新迭代\n",
    "            best_accu = avg_valid_accu                                   # 更新精度\n",
    "            best_loss = avg_valid_loss                                   # 更新损失\n",
    "    \n",
    "    # 显示结果\n",
    "    train_time = time.time() - train_time # 设置训练时间\n",
    "    print('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}'.format(\n",
    "        train_time, best_epoch, best_loss, best_accu))\n",
    "    \n",
    "    # 写入文件\n",
    "    with open(result_txt, 'a') as file:\n",
    "        file.write('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}\\n'.format(\n",
    "            train_time, best_epoch, best_loss, best_accu))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "infer time: 0.154507s, infer value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGylJREFUeJztnWlsXOd1ht9zZ+Um7pQoURIteZUVW04c17HjVFmcOGkAJ0VhJGgDA3WWAgnaoPljuECbAv2RAk2CoghSJKhrB0jjpHFSu47T2HGdOnYTWbIta7MsibQW7hSX4XAWznK//phhyuF7eDniSBQpnwcQxDm8c+937/DMved857yfOOdgGIaOd7kHYBhrGXMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAIwBzGMAGpyEBG5R0TeFJFTIvLgxRqUYawVZKUz6SISAnACwN0ABgDsB/Bp59yxpd7T0dHhent7V3Q8Yzn4c8zPzZEtlU6TrbFpg7rHcDhc+7BWgK/YisWCuu3cXJZsoTB/7+dylduNjYwjMZ2U5cZSyxW4DcAp51w/AIjIYwDuBbCkg/T29uLAgQM1HNJYkiI7w8jZPrLte/lVst31oXvUXba1d9Q+rmUoKrZ0ka3J2Un1/f19b5Cttb2BbGfPnqx4/eefe6iq8dXyiLUFwLkFrwfKtgpE5PMickBEDoyPj9dwOMNYfS55kO6c+45z7lbn3K2dnZ2X+nCGcVGp5RFrEMDWBa97yrYLwqqJLxxfeR6X/BTZkmP9ZHv+yZ/wdkl+jgeAP/nsZ9mofF6+r3yGylevAz/y55X3Dg2fJdvk9IA6xuFzR8nWf/I82RIzlddnLptS97eYWu4g+wFcIyJXiUgUwKcAPFnD/gxjzbHiO4hzriAiXwLwCwAhAA8759idDWMdU1Mezzn3NICnL9JYDGPNYTPphhHA5ZkJWgaRZedv3jZoKQxPlNmDYpLfm+G0eoOfI9vE8Ih67NGRUbKFhL9Tm1uayRaJRsjmK0G6czwtGOa3Il/MqGNs39hOttFxDtKH+4Yq95fPq/tbjN1BDCMAcxDDCMAcxDACMAcxjADWZJC+WmhVo87nor/CFAd9mcQsvzfKRXIbtmzWD64Eu6IErJ7Ps+Yzw+fIdvrIb8n21hvHeX9eVNkfz1wDwK+efpxsrZu3ku2OO+/iN4e5QnhiOkG2uVlOEGSzY2RzBU5CAMDYJFcLTE3z5+X8xde7ukSQ3UEMIwBzEMMIwBzEMAIwBzGMAN7WQTp8npE+f4oD27FXXiRbepIDzpEcf99ce9de9dDX3Hwr2bwIfxyHjx4m22vPP0+2pBK4z4zxTHgkHCNbdmKIbADw/M/OkO2G3/8I2d7zvg/yPud4xn5qjPfXv59L+UaHuBOyffs2dYxpn8vW82m+jlGvq+K1VPmnb3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIoKYsloicBpBESd6o4Jzj1MwaxmW5rGTiTc6gYHqGTG0hRcjM48xN/wvPqscOOy51iG/mTM33fvyfZDt64CDZdrRymUubx2NsUDJlxZDSgAGg/wRnt1488WOydffcSLa7bruBbOPH/5dsrz/zU7LNTbMARWpwlzrG+l3vYlsd63k1XdVa8Toaq04+4WKked/vnOPiF8O4ArBHLMMIoFYHcQCeEZFXROTz2gamrGisZ2p1kPc6594J4KMAvigi71u8gSkrGuuZWmV/Bsv/j4nIT1EStH7hgnZyGfUZvCj3RjR2cf/G+MBbZMuOs9JfQ5T7OWay+gke/61SvtK6nWzPPPMSb5fk3ogmr5ttrXGypeY4cD9+VhdtGEmxZMTABAfQ33/kX3m7g11kS59j4fKGIpeKxOq4HGYuxar0ALC9kQNyb+PVZMtK5Wcd0pQhFFZ8BxGRBhFpmv8ZwIcBHFnp/gxjLVLLHWQjgJ+WJXrCAP7NOfdfF2VU
hrFGqEV6tB/AzRdxLIax5rA0r2EEcPn7QTTpwGoD96VWTqjy/U5ZYmzTO/immJ+dJlvf2TfJlp7kNHYuVqce+8QJXhkp1cjqgeE8n+TMBK+2lFBWVYpv58B9ZoqD7ENn9CB9PMdJjKZmVlE8e+p1su2b5CUVrungwDga4fObnmNbU5d+HYeHuA9mQ30bH6dtkQKjVLfsht1BDCMAcxDDCMAcxDACMAcxjAAue5CuxUpKJfgS772A9Q2VJRVEWR8vEuPZ5y233cn7UyZih1/lWe8eRYkQACbOs2DEoX2vka0uzIF7RxMHz3vv4jH+3s1cIv5P3/oW2ZIZLtMH9GuhKRymlVnu2FZelsB3HLiPjnErQbh1I9mkQS9Tev0otyckXmHhje4dOypep2b4uBp2BzGMAMxBDCMAcxDDCMAcxDACWPUgffGi85qH+krwnc1x/3hUmQkH9HX0PG16XQncC8r0fN8kdxRPKQHs3LW7yXbju+5Qx5g/y7PhP/rZL3m7DJeDf/KevWT7w49/mGwnT/HSAGMpTg7kXEgdY8TxttEwb9sU52vR0MJBdSLP59KwkWf7XR0vnTAwri9/UMxwEiOnaAg8/2RloXlymqsjNOwOYhgBmIMYRgDmIIYRgDmIYQSwbJAuIg8D+DiAMefc7rKtDcAPAfQCOA3gPucc11EvwncOc/nKWdu40hc+k+b1/17av49sGxob1ePccuNNZGuqqydbscj92YPjLJb2qxc5eH7rLK/rN6fMSMc296pjLCR5VnnsDC8PMJvka7Gzl2fnw+CAejrBwWrO5yC7UNRWawT8NAfGnuMSglCcP8OJSf5zGB3jZEedsq5jQzMnZBpbeDsAaFKSBnVhTrRs7WipeN13Tl/yYTHV3EEeAXDPItuDAJ5zzl0D4Lnya8O44ljWQZxzLwBYnJO8F8Cj5Z8fBfCJizwuw1gTrDQG2eicGy7/PIKSgIPKQuG48yYcZ6wzag7SnXMOSze/VgjHdZhwnLHOWOlM+qiIdDvnhkWkGwCv/K4gAsiioGpmloPQ/QdfJdvZ4UGyxaIsMAYAnW0sJnZd706yJWYmyHbwIAu6DZ8+RraRsxxwjk3xuRw8zIrmAHBbz/Vk27GJv0Cm2ri/urmDZ5/PDXFf+fAwB6KpJAfPLY16v3dqloP0mSmuANjR1UO2xjj/aaXrFGX5AidKiikeY9HTy9NzrVxWjzAnLJqbK88xHKru3rDSO8iTAO4v/3w/gCdWuB/DWNMs6yAi8gMAvwFwnYgMiMgDAL4G4G4ROQngQ+XXhnHFsewjlnPu00v8itf+NYwrDJtJN4wAVrXc3flAca4ygHpp38u03StHD5Ft5/UcCA6dS6jH+Y+nniPbxz+WJ1vfaRZv6zvHSu5eiMu5J5VZ4cGB02SLF9+tjvEdvb1k+7M//QzZtNnwnS0s3jY0xEmMk4c5uZCc4FR7c7sS6AIoFpQydmXSfUtrE9mcshyd+PzmkMcJ0FBIaUPI8+cHAGlF1C8U5pn9ol+ZDHDQqwcWY3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIYFWzWEW/iORsZebpv1/gXov2zVwqMpfl/okz/bpsvyiZkZcPserhESVbJsolCWmXKcw9C3s/uIdsXa1cKgIAhTRneXZfdx3ZPGW5goFfcJau7jxnc+5u4nUCN13LvTIHxofJBgDH67j3o7eHy1w6lbKSbJbLVLS+E9/n7JS2fmAsrJfD5JSelajS++NF9LKk5bA7iGEEYA5iGAGYgxhGAOYghhHAqgbp4gkiDZXBUnMbCy8MDrKk/aHXeQn2M6e4/wIAuns4oGvfxCUbvs+9CFOTvM+IEvT37lAC4M1ccpGZ00skclkO0ouK6EPmNJeQpE9zUJ1IcDBfp5SkvHsbl+x0x3jcALBhgvtJwq0snuBH+Dq6IgfaogTkxTwnX0SLpxWxidI+ufejMMf7jHqL329rFBpGzZiDGEYA5iCGEYA5iGEEsFJlxa8C
+ByA+eaCh5xzTy+3r1Q6i32vVfZgFBXp/VCIh/VWP/dpDA7qQXpjK4sfFIutZEsmeW09LUi/Sglsuzo5SB8YOEG21rAusx+5kRMJ4QRL+Z87eJRsR2d4GYGfHePtEj4Hqy1xnmX+8HW3qmO8I8oKjudGT5Mt1MwBeaGeezrySvDsfE5MOJ8/fy3wBoBiUZmJd8qM/eKlMqpc33KlyooA8E3n3J7yv2WdwzDWIytVVjSMtwW1xCBfEpFDIvKwiPDzS5mFyoqJKlf1MYy1wkod5NsAdgLYA2AYwNeX2nChsmJzS8tSmxnGmmRFM+nOudH5n0XkuwCequZ9c7kM3jp9uHIAilR9VzuXu4vSZB+v02dXP/SBj5Dt+l07yFacYwXHrjZFOr97G9k623j2ecdWLlff1rlZHaMm7JcY4uUPJmZYtLIfHJg23cRl7IUMVw9MT7LQxRNnWNwBAG7s4tL2q7Rp7hFOLmSaeYbbFbhFoFDgIN3Pc9BfXGLmO53lpEq8QVlbsW7xuC/hTHpZbnSeTwLgOhDDuAKoJs37AwB7AXSIyACAvwGwV0T2oOSGpwF84RKO0TAuGytVVvyXSzAWw1hz2Ey6YQSwquXu0aiPzb2VAV1rB8/s5vMcuH3kD1ihcGKCg0MACMc5SMvleJ+33HIj2bIpDiSHlKUO9tzA793Zu51s0+d12f7hES4lnzw3QDbvat7nXe/fS7asx4HtzCxfnwJfGhx98zAbAZx98xTZukIc3G7wOIHifN7OE95OlJYDpwyysERMnVMUF8NFRZmxUHktnDLbrmF3EMMIwBzEMAIwBzGMAMxBDCOAVQ3Sk6kEXtj/8wpbQQnItvVyufqeO3aR7UyfLhznCQe7k7O8HqFf5Jn4ZIKDxokZDrRffp1npI/38ez64KAepMeV8u3rY7wMgdfAM/EjSln8S/t/TbaCEodGYlxmn5jVVx/ORfj6JOKcDAiHeLs0+PyKSv94aHEZOoCwYssraxkCgCf8HR8K83iyc5XJF19JIqj7r2orw3ibYg5iGAGYgxhGAOYghhHAqgbpsXgYO6+uDETzSrlz1yZtVphLwZMpvdExHOaS7HyR19tLJDmAzitTtm09nDSIxDhID8W5V3z79fp3kF9ke1OYg/xfv8jrKB49yWJyTU3cayOeorqe40qBiWn9OvqO3+8UtfqkokCfyXG/vwjPcEejvJ6gZsso6v4AEI7y34rn8bUtUILAgnTDqBlzEMMIwBzEMAIwBzGMAKrpKNwK4HsANqIU2XzHOfePItIG4IcAelHqKrzPOcfR2gIa6uK4dU9l3/asUpJ97NjrZJuc5l1fv2u3epymxg3amZBlbJwDtXyOt0tO8zJfMymefW5v26TYdMGX2Sx/N8VDHGiH6zlwL+b5mkWFVfLrG1mJ3VMSAdPj59QxtnT3kq01yn8yiUkWzPOFky+xGAffnhK4Fwpcwq61QABAg7LcWlEpIWhorFS69zxddJDGV8U2BQBfcc7tAnA7gC+KyC4ADwJ4zjl3DYDnyq8N44qiGuG4Yefcq+WfkwDeALAFwL0AHi1v9iiAT1yqQRrG5eKCYhAR6QVwC4B9ADY65+ZXchlB6RFMe8/vhOOmJ3mewDDWMlU7iIg0AngcwJedcxUzbM45hyVmXhYKx7W08TOxYaxlqnIQEYmg5Bzfd879pGwendfHKv/PCmeGsc6pJoslKMn8vOGc+8aCXz0J4H4AXyv//8Ry+yr6BSRmKwUQPHBZyEyCsxDHj3PW6FT//6jH6dnGyow37dlJtm3KdnUeZ8CcIgJQVPpYohHutRCuhAAA1Gf4httdz2O8ZQ9naTqaudzjpRdeIltiirWQtf6b8UH9u801cH9K8VoeI5TrowlnxMJ8MTIpLknxi9z7EY3r3+UhRXEzl1GUKRZXGlVXaVJVLdadAD4D4LCIHCzbHkLJMX4kIg8AOAPgvuoOaRjrh2qE416ENolQ4oMXdziGsbawmXTDCMAcxDAC
WNV+EE+A+milTzqfg6w7b38X2XbuvIFs/WdOq8cZG2fRhukJRSY/wgmC0QwnA1paOHBvauKSDRdRylRmuG8EANoaeN3Dzi7uO0lu5cB//29+Q7aJaVZ/9JVrqyHcKgMAaGvjX7Rt4XKYlPI1G1HEFKLachXC0XImw6U0ztOj6oKizKiddnrRPqu9NnYHMYwAzEEMIwBzEMMIwBzEMAJY1SAd4uCFKoMqL6LI6SsL03ds2kK2G3br6/9lsxzk+Yqq3/D5YbKNJTjYHZsZJdumbg6om5s5qPWX6DuYzfN300T2ZbINTrKwxJFjPGs+l+Vxx+NLRN+LaGjWA+CtbUrvR/Is2bwWPk5LhKsUfHBPhyqw4Pizmk3q1zHkKYG/sgAkTfYvNbO3CLuDGEYA5iCGEYA5iGEEYA5iGAGsapCezc3hxFDlunfNLTwjHctxYLohzs1WrcpsNgDEldJoDywY0NXK5dyRMM9czyR5dj3kOMqbmeby8tFxXnYBABKjrBR5qoPFKnqabyHbH9/3PrId3s/v1dZlbGllEYk5pUwfANw0VwEcOXaIbL2dLBjR3sAl+QVFCXNCKW3fEOHZeqeIOwDAbIIFNeL1/LdSv6FyjJ6nVzgsxu4ghhGAOYhhBGAOYhgBmIMYRgC1KCt+FcDnAMxHsA85554O2lfRL2J6tjIAzxZY1j6mLC2Qb2omW3J2KXU8LmWur+PArbG+m2zxKAecnc1c7p5X1A215RQGTg2pIwwrSxMcGmWFw3PKZPi1US79b1Ouz+YurjTwlPLwbL0eAE9EuFd9CzgxUhfmY9c1KIqQaT6ZfJFVFHNZXqIhn9PXKEwrypyxGB+7tbVS9TIUrk5jpJos1ryy4qsi0gTgFRF5tvy7bzrn/qGqIxnGOqSanvRhAMPln5MiMq+saBhXPLUoKwLAl0TkkIg8LCKqSvNCZcVUgm+nhrGWqUVZ8dsAdgLYg9Id5uva+xYqKzYoVbqGsZapaiZdU1Z0zo0u+P13ATy13H6ikTh6Nl5dYSsoUvWeUq6cyfCs8Ni0rvWrzXxv3c5LE6QVOf5skvfZ2KjMFLcrs/ARFnnbsV1f/6++kQPW/j4u3Y6FlSUMuvmatWzkRMLsLM8yh4ocAO+88WqyAYB/nMvO8wUedzymLEHg8RjbG3m7cITPeeo8Vx+Iz/oBAJDO8FNJOMbbeqHKP3VtvUSNZe8gSykrzsuOlvkkgCNVHdEw1hG1KCt+WkT2oJT6PQ3gC5dkhIZxGalFWTFwzsMwrgRsJt0wAljVcnfnisgVKoPgWIxLrRvquNy5WOCZ1HSClcEBoKGeA79ingPyyTSvexhX1uDTFNp9jwPYdI5n9rs2aeslAvX1HLBu2qSUiBf5OHM+zx63t3EPeCbB28UjnHAI1fN2ABAf54C8boTPx/M58C+Ckx1eiD/rugb+rNMpTshE4rrQW9FxQsYXDtwzhcoqB1/pe9ewO4hhBGAOYhgBmIMYRgDmIIYRwKoG6UW/iFS6cma54LNoWXKWhdpCwkGtCAe1ANDcxPZ0mvcZUZYEkzAH+KksB9/JIS5t12auoZwfADifM+chRR3e95VgV8m6F9PcIhAOcWCbSnNAnczpffPSzLP40sABfeo8B9V5JQgugI89l+HrmHccZA8MD6pjHBnjSoXOzZwMcOnKJE9RKfvXsDuIYQRgDmIYAZiDGEYA5iCGEYA5iGEEsLqlJr6HfKayVCE1y83z2kLyuRxnaaJKuQcATL3FJSgzKc6C7H7HtWRLjHBGxxO+TOoad0pm6q0+PfsSi3JWrqWNsy/Nrfwd1tzCZTPIcbYrrpSzJGZZJCOd5iwUALiMIvAQ4cxfHlx+4ucVgYYQfy75MGex0nnOTPWfZUELAEgm+G+gpYf7QQpe5Tk66NnFxdgdxDACMAcxjADMQQwjgGpabuMi8rKIvC4iR0Xkb8v2q0Rkn4icEpEfiojyYGwY65tqgvQ5AB9wzs2WxRteFJGfA/hLlITj
HhORfwbwAEpKJ0uSz/kYGqgsx/CVwDYa4RKHwWEOnnM5XRAhrCxh0NLKgeTgsFLS4vF4PPD+6pW+Ck2VMRzTpY6OnzpOts1ZHmP4PJdnRCKcIGisZzXBhgZWPMxkOEgPRZfqteAAujHew9t5SsNMhktSpgp8vaWLy3MmZ/mzTs7qY8w6/o7vfScrT+6+ZXvF64OHn1H3t5hl7yCuxHwxUqT8zwH4AIAfl+2PAvhEVUc0jHVEVTGIiITKgg1jAJ4F0Adg2jk3nwccwBJqiwuF49KzejrRMNYqVTmIc67onNsDoAfAbQCur/YAC4Xj6hstTDHWFxeUxXLOTQN4HsB7ALSI/G4GrQeAPiNmGOuYapY/6ASQd85Ni0gdgLsB/D1KjvJHAB4DcD+AJ5bb19xcHn19w5X7V5YqaGpk28wU+3IyqT+y7drNsv+921kJcWDoNB+7iSWGXZ5nXesbOKCOKYF77zZdwa+tjWeas1meaZ5W1glMTClqlG3Kun557m3xPD5uInVeHWOuyLPz0wkWSdiQ4hn7mBI8Zz3eXyzK2yWSSh9LSv8ub97CTyXxTkW0o7EyOeGUXhmNarJY3QAeFZEQSnecHznnnhKRYwAeE5G/A/AaSuqLhnFFUY1w3CGUFN0X2/tRikcM44rFZtINIwBzEMMIQJyrruz3ohxMZBzAGQAdAPTIcP1h57I2We5ctjvnOpfbyao6yO8OKnLAOXfrqh/4EmDnsja5WOdij1iGEYA5iGEEcLkc5DuX6biXAjuXtclFOZfLEoMYxnrBHrEMIwBzEMMIYNUdRETuEZE3y626D6728WtBRB4WkTERObLA1iYiz4rIyfL/XO24BhGRrSLyvIgcK7dS/0XZvu7O51K2ha+qg5QLHr8F4KMAdqG0Uu6u1RxDjTwC4J5FtgcBPOecuwbAc+XX64ECgK8453YBuB3AF8ufxXo8n/m28JsB7AFwj4jcjlLV+Tedc1cDmEKpLfyCWO07yG0ATjnn+p1zOZRK5e9d5TGsGOfcCwAWN8Lfi1LLMbCOWo+dc8POuVfLPycBvIFSV+i6O59L2Ra+2g6yBcBCibwlW3XXERudc/NNLiMANl7OwawEEelFqWJ7H9bp+dTSFh6EBekXEVfKma+rvLmINAJ4HMCXnauUMVlP51NLW3gQq+0ggwC2Lnh9JbTqjopINwCU/2ex4TVKWcbpcQDfd879pGxet+cDXPy28NV2kP0ArilnF6IAPgXgyVUew8XmSZRajoEqW4/XAiIiKHWBvuGc+8aCX6278xGRThFpKf883xb+Bv6/LRxY6bk451b1H4CPATiB0jPiX6328Wsc+w8ADAPIo/RM+wCAdpSyPScB/BJA2+UeZ5Xn8l6UHp8OAThY/vex9Xg+AG5Cqe37EIAjAP66bN8B4GUApwD8O4DYhe7bSk0MIwAL0g0jAHMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAL4P/reBAlsXKWPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "image_path = './work/out/img.png' # 图片路径\n",
    "model_path = './work/out/hs-resnet-best' # 模型路径\n",
    "\n",
    "# 加载图像\n",
    "def load_image(image_path):\n",
    "    \"\"\"\n",
    "    Read an image file and convert it into the network input format.\n",
    "\n",
    "    The image is forced to 3-channel RGB, resized to 32x32, scaled to [0, 1],\n",
    "    normalized per channel, reordered from HWC to CHW and given a leading\n",
    "    batch axis.\n",
    "\n",
    "    Args:\n",
    "        image_path: path of the image file to read.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray of dtype float32 with shape (1, 3, 32, 32).\n",
    "    \"\"\"\n",
    "    # Open inside a context manager so the file handle is released promptly.\n",
    "    with Image.open(image_path) as source:\n",
    "        # PNG files may be RGBA, palette or grayscale; the per-channel statistics\n",
    "        # below only broadcast against exactly three channels, so force RGB.\n",
    "        rgb = source.convert('RGB')\n",
    "        # Image.LANCZOS is the same filter as the deprecated Image.ANTIALIAS alias,\n",
    "        # which newer Pillow releases no longer provide.\n",
    "        rgb = rgb.resize((32, 32), Image.LANCZOS)\n",
    "\n",
    "    image = np.array(rgb, dtype=np.float32) # HWC pixel values in [0, 255]\n",
    "\n",
    "    # Per-channel statistics used for normalization.\n",
    "    # NOTE(review): these must match the values used when the model was trained - confirm.\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, -1))\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, -1))\n",
    "\n",
    "    image = (image/255.0 - mean) / stdv # scale to [0, 1], then standardize each channel\n",
    "    image = image.transpose((2, 0, 1)).astype(np.float32) # HWC -> CHW, back to float32\n",
    "\n",
    "    # Add the leading batch axis expected by the model\n",
    "    image = np.expand_dims(image, axis=0)\n",
    "\n",
    "    return image\n",
    "\n",
    "# Classify a single image with the trained network\n",
    "with fluid.dygraph.guard():\n",
    "    # Preprocess the image file and wrap the array as a dygraph variable\n",
    "    image = fluid.dygraph.to_variable(load_image(image_path))\n",
    "\n",
    "    # Build the network and restore its trained parameters\n",
    "    model = ResNet()\n",
    "    model_dict, _ = fluid.load_dygraph(model_path)\n",
    "    model.set_dict(model_dict)\n",
    "    model.eval() # switch to evaluation mode\n",
    "\n",
    "    # Forward pass, measuring wall-clock time around the call\n",
    "    start_time = time.time()\n",
    "    infer = model(image)\n",
    "    infer_time = time.time() - start_time\n",
    "\n",
    "    # Label names, written grouped five per line and then sorted alphabetically\n",
    "    # so that the list position can be looked up with the predicted class index\n",
    "    vlist = sorted([\n",
    "        'beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "        'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "        'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "        'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "        'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "        'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "        'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "        'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "        'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "        'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "        'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "        'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "        'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "        'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "        'baby', 'boy', 'girl', 'man', 'woman',\n",
    "        'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "        'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "        'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "        'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "        'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor',\n",
    "    ])\n",
    "\n",
    "    # Report the elapsed time and the name at the highest-scoring index\n",
    "    predicted_label = vlist[np.argmax(infer.numpy())]\n",
    "    print('infer time: {:f}s, infer value: {}'.format(infer_time, predicted_label) )\n",
    "\n",
    "    # Show the original (unprocessed) image file in a 3x3 inch figure\n",
    "    image = Image.open(image_path)\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    plt.imshow(image)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.8.4 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
